{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ef0a4ee4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\oscar'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the kernel's current working directory (IPython line magic).\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "japanese-margin",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\oscar'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "# Same check as %pwd, but through the standard library.\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "executive-insulin",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the working directory to the project folder.\n",
    "# NOTE(review): absolute, machine-specific Windows path - it has to be edited before\n",
    "# the notebook can run on another machine.\n",
    "os.chdir('C:\\\\Users\\\\oscar\\\\Documents\\\\experimentos\\\\2025_Four_carnivores_MAML')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0888b480",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\oscar\\\\Documents\\\\experimentos\\\\2025_Four_carnivores_MAML'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm that the working directory is now the project folder.\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "130fad96",
   "metadata": {},
   "outputs": [],
   "source": [
    "#FIRST, WE IMPLEMENT A COLOR AUGMENTATION FUNCTION\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "#from tensorflow.keras.applications.vgg19 import preprocess_input #224 x 224\n",
    "from tensorflow.keras.applications.resnet50 import preprocess_input #224 x 224\n",
    "#from tensorflow.keras.applications.densenet import preprocess_input #224 x 224\n",
    "#from tensorflow.keras.applications.efficientnet_v2 import preprocess_input #299 x 299 or 300 x 300\n",
    "#from tensorflow.keras.applications.xception import preprocess_input #299 x 299\n",
    "import cv2\n",
    "\n",
    "def random_hue_brightness_saturation(image, preprocess_fn=None):\n",
    "    \"\"\"Randomly jitter hue, saturation or brightness, then apply model preprocessing.\n",
    "\n",
    "    Exactly one of the three adjustments is chosen per call, using NumPy's\n",
    "    global random state:\n",
    "      * hue: shifted by a value drawn from U(-10, 10) and wrapped modulo 180,\n",
    "      * saturation: scaled by a factor drawn from U(0.8, 1.2), clipped to [0, 255],\n",
    "      * brightness (V channel): scaled by a factor drawn from U(0.8, 1.2), clipped to [0, 255].\n",
    "\n",
    "    Args:\n",
    "        image: RGB image array indexed as [row, column, channel] with values in\n",
    "            [0, 255]. Input that is not already uint8 is clipped to [0, 255] and\n",
    "            cast to uint8, because the modulo-180 hue wrap and the 0-255 clipping\n",
    "            below assume OpenCV's 8-bit HSV value ranges.\n",
    "        preprocess_fn: Optional callable applied to the augmented RGB image.\n",
    "            When None, the preprocess_input function imported at the top of\n",
    "            this cell is used, which is the original behaviour. Pass the\n",
    "            preprocessing function of another backbone to reuse this helper\n",
    "            without re-importing.\n",
    "\n",
    "    Returns:\n",
    "        The augmented image after preprocess_fn has been applied to it.\n",
    "    \"\"\"\n",
    "    if preprocess_fn is None:\n",
    "        preprocess_fn = preprocess_input\n",
    "\n",
    "    # Guard against float input: cv2 uses different HSV ranges for float images,\n",
    "    # which would silently break the arithmetic below.\n",
    "    image = np.asarray(image)\n",
    "    if image.dtype != np.uint8:\n",
    "        image = np.clip(image, 0, 255).astype(np.uint8)\n",
    "\n",
    "    # Convert the image to HSV\n",
    "    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)\n",
    "\n",
    "    # Randomly choose which single adjustment to apply (0, 1 or 2)\n",
    "    random_value = np.random.randint(0, 3)\n",
    "\n",
    "    if random_value == 0:\n",
    "        # Randomly adjust hue; the modulo keeps the shifted value inside [0, 180)\n",
    "        hue_shift = np.random.uniform(-10, 10)\n",
    "        hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue_shift) % 180\n",
    "    elif random_value == 1:\n",
    "        # Randomly adjust saturation\n",
    "        saturation_scale = np.random.uniform(0.8, 1.2)\n",
    "        hsv_image[:, :, 1] = np.clip(hsv_image[:, :, 1] * saturation_scale, 0, 255)\n",
    "    else:\n",
    "        # Randomly adjust brightness\n",
    "        brightness_scale = np.random.uniform(0.8, 1.2)\n",
    "        hsv_image[:, :, 2] = np.clip(hsv_image[:, :, 2] * brightness_scale, 0, 255)\n",
    "\n",
    "    # Convert back to RGB\n",
    "    augmented_image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB)\n",
    "\n",
    "    # Apply the backbone-specific preprocessing\n",
    "    return preprocess_fn(augmented_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d239a31-5db2-453e-a940-bbf505352c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, models, optimizers\n",
    "#from tensorflow.keras.applications import densenet\n",
    "#from tensorflow.keras.applications import resnet50\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "import os\n",
    "import cv2\n",
    "from glob import glob\n",
    "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
    "\n",
    "# Set your directory containing the images\n",
    "data_directory = 'C:\\\\Users\\\\oscar\\\\Documents\\\\experimentos\\\\2025_Four_carnivores_MAML\\\\images'\n",
    "\n",
    "# Function to load and resize images from a directory\n",
    "# Import the augmentation function\n",
    "\n",
    "def load_images(directory, target_size=(224, 224), pattern='*.bmp'):  # change target_size according to the TL model\n",
    "    \"\"\"Load, resize, colour-augment and preprocess every matching image in a folder.\n",
    "\n",
    "    Args:\n",
    "        directory: Folder that is searched (non-recursively) for image files.\n",
    "        target_size: (height, width) every image is resized to by load_img.\n",
    "        pattern: Glob pattern selecting the files to load; defaults to '*.bmp'.\n",
    "\n",
    "    Returns:\n",
    "        A NumPy array stacking the processed images along the first axis.\n",
    "        When no file matches, an empty array of shape\n",
    "        (0, height, width, 3) is returned so that it can still be\n",
    "        concatenated with the arrays of the other classes.\n",
    "    \"\"\"\n",
    "    image_list = []\n",
    "    file_pattern = os.path.join(directory, pattern)\n",
    "    # glob() returns files in arbitrary, OS-dependent order. Sorting makes the\n",
    "    # image order - and therefore the seeded train/val/test split - reproducible.\n",
    "    for filename in sorted(glob(file_pattern)):\n",
    "        print(f\"Loading, resizing, and preprocessing image: {filename}\")\n",
    "\n",
    "        # Load image and resize\n",
    "        img = load_img(filename, target_size=target_size)\n",
    "\n",
    "        # Convert to OpenCV format (uint8 required for cv2 functions)\n",
    "        img_array = img_to_array(img).astype(np.uint8)\n",
    "\n",
    "        # Apply color augmentation (this also applies the model preprocessing)\n",
    "        image_list.append(random_hue_brightness_saturation(img_array))\n",
    "\n",
    "    if not image_list:\n",
    "        # np.array([]) would be 1-D and make np.concatenate fail later on.\n",
    "        print(f\"Warning: no files matching {pattern} found in {directory}\")\n",
    "        return np.empty((0, target_size[0], target_size[1], 3), dtype=np.float32)\n",
    "\n",
    "    return np.array(image_list)\n",
    "\n",
    "\n",
    "# Read the four class folders; a folder's position in this tuple is its label.\n",
    "class_folders = ('Crocodiles', 'Hyenas', 'Leopards', 'Lions')\n",
    "class1_images, class2_images, class3_images, class4_images = (\n",
    "    load_images(os.path.join(data_directory, folder)) for folder in class_folders\n",
    ")\n",
    "\n",
    "# One float label per image: 0.0, 1.0, 2.0 and 3.0 for the four classes.\n",
    "class1_labels, class2_labels, class3_labels, class4_labels = (\n",
    "    np.full(len(images), float(label))\n",
    "    for label, images in enumerate(\n",
    "        (class1_images, class2_images, class3_images, class4_images)\n",
    "    )\n",
    ")\n",
    "\n",
    "# Stack the per-class arrays into one image array and one label array.\n",
    "all_images = np.concatenate(\n",
    "    (class1_images, class2_images, class3_images, class4_images), axis=0\n",
    ")\n",
    "all_labels = np.concatenate(\n",
    "    (class1_labels, class2_labels, class3_labels, class4_labels), axis=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e6276142",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data into training (70%) and the rest (30%).\n",
    "# The classes are imbalanced, so stratify on the labels: every split then keeps\n",
    "# the class proportions of the full data set instead of leaving the smallest\n",
    "# class under-represented by chance.\n",
    "train_images, temp_images, train_labels, temp_labels = train_test_split(\n",
    "    all_images, all_labels, test_size=0.3, random_state=42, stratify=all_labels\n",
    ")\n",
    "\n",
    "# Split the remaining 30% into validation (15%) and testing (15%),\n",
    "# again preserving the class proportions.\n",
    "val_images, test_images, val_labels, test_labels = train_test_split(\n",
    "    temp_images, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "61c8eadf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training images: 907\n",
      "Number of validation images: 194\n",
      "Number of testing images: 195\n"
     ]
    }
   ],
   "source": [
    "# Report how many images ended up in each split.\n",
    "num_training_images, num_validation_images, num_testing_images = (\n",
    "    len(train_images), len(val_images), len(test_images)\n",
    ")\n",
    "\n",
    "for split_name, split_size in (\n",
    "    ('training', num_training_images),\n",
    "    ('validation', num_validation_images),\n",
    "    ('testing', num_testing_images),\n",
    "):\n",
    "    print(f'Number of {split_name} images: {split_size}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1dcb1c6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1296, 299, 299, 3)\n",
      "(1296,)\n",
      "(array([0., 1., 2., 3.]), array([124, 364, 544, 264], dtype=int64))\n"
     ]
    }
   ],
   "source": [
    "# Sanity-check the assembled data set.\n",
    "print(all_images.shape)  # Expected: (total images, height, width, 3), where height/width are the target_size passed to load_images\n",
    "print(all_labels.shape)  # Expected: (total images,)\n",
    "print(np.unique(all_labels, return_counts=True))  # Label values and the number of images per label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac59971",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab737eb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#NOW APPLY MODELS:\n",
    "\n",
    "# With regularization: Dropout, Early Stopping, a scheduled learning rate (Learning Rate Scheduler) and Data Augmentation\n",
    "\n",
    "#Select model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "0a71a3c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ResNet50 backbone with ImageNet weights and no classification head.\n",
    "base_model = tf.keras.applications.ResNet50(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# Mark every backbone layer as non-trainable so its weights stay fixed.\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "550ef907",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import ResNet152\n",
    "\n",
    "# ResNet152 backbone with ImageNet weights and no classification head.\n",
    "base_model = ResNet152(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# Mark every backbone layer as non-trainable so its weights stay fixed.\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "af14ba29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.xception import Xception\n",
    "\n",
    "# Xception backbone with ImageNet weights and no classification head.\n",
    "# Note the 299 x 299 input size used for this backbone.\n",
    "base_model = Xception(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(299, 299, 3),\n",
    ")\n",
    "\n",
    "# Mark every backbone layer as non-trainable so its weights stay fixed.\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "cd692445",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import EfficientNetV2L  # change L to M to switch variant\n",
    "\n",
    "# EfficientNetV2L backbone with ImageNet weights and no classification head.\n",
    "# Note the 299 x 299 input size used for this backbone.\n",
    "base_model = EfficientNetV2L(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(299, 299, 3),\n",
    ")\n",
    "\n",
    "# Mark every backbone layer as non-trainable so its weights stay fixed.\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False"
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "13b4b14c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import DenseNet201\n",
    "\n",
    "# DenseNet201 backbone with ImageNet weights and no classification head.\n",
    "base_model = DenseNet201(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# Mark every backbone layer as non-trainable so its weights stay fixed.\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "223daec5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968f985e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#YOU CAN USE A SIMPLE ARCHITECTURE OR A COMPLEX ONE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "11677392",
   "metadata": {},
   "outputs": [],
   "source": [
    "#SIMPLE ARCHITECTURE\n",
    "\n",
    "#Simple model\n",
    "\n",
    "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "# Seed the three random number sources used in this notebook with the same value.\n",
    "# Set random seed for TensorFlow\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "# Set random seed for numpy (global RNG, used by np.random.* calls)\n",
    "np.random.seed(42)\n",
    "\n",
    "# Set random seed for Python random module\n",
    "random.seed(42)\n",
    "\n",
    "# Ask TensorFlow for deterministic kernels and pin the Python hash seed.\n",
    "# NOTE(review): both are environment variables set from inside an already running\n",
    "# kernel. PYTHONHASHSEED is normally read at interpreter start-up, and\n",
    "# TF_DETERMINISTIC_OPS may need to be set before TensorFlow is imported -\n",
    "# confirm that they actually take effect here.\n",
    "os.environ['TF_DETERMINISTIC_OPS'] = '1'\n",
    "os.environ['PYTHONHASHSEED'] = '42'\n",
    "\n",
    "# Now your entire script will produce reproducible results as long as it doesn't rely on non-deterministic operations\n",
    "\n",
    "# Create FSSL-MAML model\n",
    "def create_maml_model(base_model, num_classes):\n",
    "    \"\"\"Stack a small classification head on top of base_model.\n",
    "\n",
    "    The head is: 3x3 convolution (512 filters, ReLU) -> global average pooling\n",
    "    -> dense (512, ReLU) -> dropout (0.5) -> batch normalization -> softmax\n",
    "    over num_classes.\n",
    "\n",
    "    Args:\n",
    "        base_model: Feature extractor placed at the bottom of the stack.\n",
    "        num_classes: Number of units of the final softmax layer.\n",
    "\n",
    "    Returns:\n",
    "        The (uncompiled) Sequential model.\n",
    "    \"\"\"\n",
    "    head_layers = (\n",
    "        layers.Conv2D(512, (3, 3), activation='relu'),\n",
    "        layers.GlobalAveragePooling2D(),  # global average pooling instead of max-pooling\n",
    "        layers.Dense(512, activation='relu'),\n",
    "        layers.Dropout(0.5),\n",
    "        layers.BatchNormalization(),\n",
    "        layers.Dense(num_classes, activation='softmax'),\n",
    "    )\n",
    "\n",
    "    maml_model = models.Sequential()\n",
    "    maml_model.add(base_model)\n",
    "    for head_layer in head_layers:\n",
    "        maml_model.add(head_layer)\n",
    "    return maml_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b53847",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "3fb6a68b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#COMPLEX ARCHITECTURE:\n",
    "\n",
    "#RESIDUAL BLOCKS\n",
    "# Change the number of filters of the first residual block of the MAML model to match the backbone:\n",
    "# ResNet architectures = 2048; DenseNet = 1920; Xception = 2048; EfficientNet = 1280\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, models\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "# Import what the seeding below needs in this cell, so that choosing the complex\n",
    "# architecture does not depend on another cell having imported these names first.\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Set random seed for TensorFlow\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "# Set random seed for numpy\n",
    "np.random.seed(42)\n",
    "\n",
    "# Set random seed for Python random module\n",
    "random.seed(42)\n",
    "\n",
    "# SE block for channel-wise attention\n",
    "def squeeze_excite_block(input_tensor, ratio=16):\n",
    "    \"\"\"Squeeze-and-excitation: rescale every channel of input_tensor by a learned gate.\n",
    "\n",
    "    The feature map is globally average-pooled, passed through a bottleneck\n",
    "    dense layer (ReLU) and a dense layer with sigmoid activation, and the\n",
    "    resulting per-channel weights are multiplied back onto input_tensor.\n",
    "\n",
    "    Args:\n",
    "        input_tensor: Feature map whose last axis is the channel axis.\n",
    "        ratio: Reduction factor of the bottleneck (channels // ratio units).\n",
    "\n",
    "    Returns:\n",
    "        A tensor shaped like input_tensor with re-weighted channels.\n",
    "    \"\"\"\n",
    "    filters = input_tensor.shape[-1]\n",
    "    # With fewer channels than ratio the integer division would give 0 units,\n",
    "    # which is not a valid Dense layer - keep at least one unit.\n",
    "    bottleneck_units = max(1, filters // ratio)\n",
    "\n",
    "    se = layers.GlobalAveragePooling2D()(input_tensor)\n",
    "    se = layers.Reshape((1, 1, filters))(se)  # back to a broadcastable 1x1 map\n",
    "    se = layers.Dense(bottleneck_units, activation='relu')(se)\n",
    "    se = layers.Dense(filters, activation='sigmoid')(se)\n",
    "    return layers.multiply([input_tensor, se])\n",
    "\n",
    "# Residual block with Depthwise Separable Convolution and SE attention\n",
    "def residual_block(x, filters, kernel_size=3, stride=1):\n",
    "    \"\"\"Residual block built from two separable convolutions plus SE attention.\n",
    "\n",
    "    Args:\n",
    "        x: Input feature map (channels on the last axis).\n",
    "        filters: Number of output channels of the block.\n",
    "        kernel_size: Kernel size of both separable convolutions.\n",
    "        stride: Stride of the first separable convolution (and of the shortcut\n",
    "            projection, when one is needed).\n",
    "\n",
    "    Returns:\n",
    "        ReLU(main branch + shortcut).\n",
    "    \"\"\"\n",
    "    shortcut = x\n",
    "    # The shortcut must match the main branch in BOTH channel count and spatial\n",
    "    # size. Project it with a 1x1 convolution when the channels differ or when a\n",
    "    # stride shrinks the main branch; otherwise layers.Add() receives mismatched shapes.\n",
    "    if stride != 1 or x.shape[-1] != filters:\n",
    "        shortcut = layers.Conv2D(filters, (1, 1), strides=stride, padding='same')(shortcut)\n",
    "\n",
    "    x = layers.SeparableConv2D(filters, kernel_size, strides=stride, padding='same', activation='relu')(x)\n",
    "    x = layers.SeparableConv2D(filters, kernel_size, strides=1, padding='same')(x)\n",
    "    x = squeeze_excite_block(x)  # Add SE block for attention\n",
    "    x = layers.Add()([x, shortcut])  # Add shortcut connection\n",
    "    return layers.Activation('relu')(x)\n",
    "\n",
    "# Modified model creation with base model, SE blocks, and depthwise separable convolutions\n",
    "def create_maml_model(base_model, num_classes, first_block_filters=None):\n",
    "    \"\"\"Build the complex head (two residual SE blocks + dense classifier) on base_model.\n",
    "\n",
    "    Args:\n",
    "        base_model: Backbone whose output feature map feeds the head.\n",
    "        num_classes: Number of units of the final softmax layer.\n",
    "        first_block_filters: Channel count of the first residual block. When\n",
    "            None it is read from the backbone output, so the block matches any\n",
    "            backbone without editing this function (2048 for a backbone with\n",
    "            2048 output channels, which was the previous hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        The (uncompiled) functional Keras model from base_model.input to the softmax output.\n",
    "    \"\"\"\n",
    "    x = base_model.output\n",
    "    if first_block_filters is None:\n",
    "        first_block_filters = x.shape[-1]\n",
    "\n",
    "    x = residual_block(x, first_block_filters)  # First residual block, matching the backbone output channels\n",
    "    x = residual_block(x, 512)   # Second residual block with fewer filters\n",
    "    x = layers.GlobalAveragePooling2D()(x)\n",
    "    x = layers.Dense(512, activation='relu')(x)\n",
    "    x = layers.Dropout(0.5)(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "    outputs = layers.Dense(num_classes, activation='softmax')(x)\n",
    "\n",
    "    maml_model = models.Model(inputs=base_model.input, outputs=outputs)\n",
    "    return maml_model\n",
    "\n",
    "#RESIDUAL BLOCKS\n",
    "# Change the number of filters of the first residual block of the MAML model to match the backbone:\n",
    "# ResNet architectures = 2048; DenseNet = 1920; Xception = 2048; EfficientNet = 1280"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0891a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#RUN THE MODEL AND ARCHITECTURE SELECTED"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6db813-969b-42af-9029-0ed14ca4abc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#FEW-SHOT SUPERVISED LEARNING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a4885e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FSSL (we keep the MAML name for the shared code's sake)\n",
    "# Set up FSSL model\n",
    "num_classes = 4  # Number of output classes\n",
    "maml_model = create_maml_model(base_model, num_classes)\n",
    "\n",
    "# Set up FSSL optimizer\n",
    "meta_optimizer = optimizers.Adam(learning_rate=1e-3)\n",
    "#meta_optimizer = optimizers.Adagrad(learning_rate=0.001)\n",
    "#meta_optimizer = optimizers.SGD(learning_rate=0.001, momentum=0.9)\n",
    "\n",
    "# Compile the FSSL model (the compiled loss/metric are what evaluate() reports)\n",
    "maml_model.compile(optimizer=meta_optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "\n",
    "# Lists to store training, validation, and testing history\n",
    "train_loss_history = []\n",
    "train_acc_history = []\n",
    "val_loss_history = []\n",
    "val_acc_history = []\n",
    "test_loss_history = []\n",
    "test_acc_history = []\n",
    "\n",
    "# Training loop for few-shot learning\n",
    "num_epochs = 100  # Maximum number of epochs\n",
    "num_tasks = 40  # Number of few-shot tasks per epoch\n",
    "validation_interval = 1  # Evaluate on the validation set every 'validation_interval' epochs\n",
    "num_shots = 5  # Number of examples per class in each few-shot task\n",
    "\n",
    "# Early stopping parameters\n",
    "patience = 15\n",
    "best_val_loss = float('inf')\n",
    "wait = 0\n",
    "\n",
    "# Data Augmentation\n",
    "datagen = ImageDataGenerator(\n",
    "    rotation_range=20,\n",
    "    width_shift_range=0.2,\n",
    "    height_shift_range=0.2,\n",
    "    horizontal_flip=True\n",
    ")\n",
    "# The former 'datagen.seed = 42' was removed: ImageDataGenerator does not read a\n",
    "# seed attribute, so it had no effect. random_transform() draws from NumPy's\n",
    "# global RNG, so reproducibility depends on np.random.seed(...) having been called.\n",
    "\n",
    "# Callbacks\n",
    "# NOTE(review): these callback objects are only created here; nothing in this cell\n",
    "# passes them to a fit() call - the loop further down does its own early stopping\n",
    "# and model saving.\n",
    "early_stopping = tf.keras.callbacks.EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    patience=patience,\n",
    "    restore_best_weights=True\n",
    ")\n",
    "# Define the directory where you want to save the best model\n",
    "save_dir = 'C:\\\\Users\\\\oscar\\\\Documents\\\\experimentos\\\\2025_Four_carnivores_MAML\\\\bestmodels'\n",
    "# Create it up front so the first model.save() cannot fail on a missing folder\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Callback to save the best model\n",
    "model_checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
    "    os.path.join(save_dir, 'best_Resn_model3_10.keras'),  # Full path to save the best model\n",
    "    monitor='val_loss',\n",
    "    save_best_only=True\n",
    ")\n",
    "\n",
    "\n",
    "# Learning rate scheduler function\n",
    "def lr_schedule(epoch):\n",
    "    \"\"\"Step schedule: start at 1e-3 and multiply by 0.1 after epochs 10, 20 and 30.\n",
    "\n",
    "    Args:\n",
    "        epoch: Epoch index the learning rate is requested for.\n",
    "\n",
    "    Returns:\n",
    "        The learning rate for that epoch.\n",
    "    \"\"\"\n",
    "    lr = 1e-3\n",
    "    # Each milestone that has been passed contributes one factor of 0.1.\n",
    "    for milestone in (10, 20, 30):\n",
    "        if epoch > milestone:\n",
    "            lr *= 0.1\n",
    "    return lr\n",
    "\n",
    "# Training loop\n",
    "for epoch in range(num_epochs):\n",
    "    # Apply the learning rate scheduler BEFORE the updates of this epoch, so that\n",
    "    # epoch e is really trained with lr_schedule(e) (previously it was applied\n",
    "    # after the epoch and therefore took effect one epoch late).\n",
    "    lr = lr_schedule(epoch)\n",
    "    meta_optimizer.learning_rate.assign(lr)\n",
    "    print(f'Learning Rate: {lr}')\n",
    "\n",
    "    for task in range(num_tasks):\n",
    "        # Build one few-shot task: num_shots random training images per class\n",
    "        task_samples = []\n",
    "        task_labels = []\n",
    "        for class_idx in range(num_classes):\n",
    "            class_indices = np.where(train_labels == class_idx)[0]\n",
    "            selected_indices = np.random.choice(class_indices, num_shots, replace=False)\n",
    "            selected_samples = train_images[selected_indices]\n",
    "            task_samples.extend(selected_samples)\n",
    "            task_labels.extend([class_idx] * num_shots)\n",
    "\n",
    "        task_samples = np.array(task_samples)\n",
    "        task_labels = np.array(task_labels)\n",
    "\n",
    "        # Data augmentation is plain NumPy work, so it does not need to be recorded on the tape\n",
    "        augmented_task_samples = np.array(\n",
    "            [datagen.random_transform(sample) for sample in task_samples]\n",
    "        )\n",
    "\n",
    "        # Fine-tune the model on the few-shot task\n",
    "        with tf.GradientTape() as tape:\n",
    "            # training=True: without it the model runs in inference mode, i.e. Dropout\n",
    "            # is skipped and BatchNormalization never uses/updates batch statistics.\n",
    "            predictions = maml_model(augmented_task_samples, training=True)\n",
    "            # Average the per-sample losses; differentiating the un-reduced vector\n",
    "            # would sum the gradients and scale the step with the task size.\n",
    "            loss = tf.reduce_mean(\n",
    "                tf.losses.sparse_categorical_crossentropy(task_labels, predictions)\n",
    "            )\n",
    "\n",
    "        gradients = tape.gradient(loss, maml_model.trainable_variables)\n",
    "        meta_optimizer.apply_gradients(zip(gradients, maml_model.trainable_variables))\n",
    "\n",
    "    train_loss, train_acc = maml_model.evaluate(train_images, train_labels, verbose=0)\n",
    "    print(f'Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc * 100:.2f}%')\n",
    "\n",
    "    train_loss_history.append(train_loss)\n",
    "    train_acc_history.append(train_acc)\n",
    "\n",
    "    if epoch % validation_interval == 0:\n",
    "        val_loss, val_acc = maml_model.evaluate(val_images, val_labels, verbose=0)\n",
    "        print(f'Epoch {epoch + 1}/{num_epochs} - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc * 100:.2f}%')\n",
    "\n",
    "        val_loss_history.append(val_loss)\n",
    "        val_acc_history.append(val_acc)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_val_loss = val_loss\n",
    "            wait = 0\n",
    "            # Save the best model\n",
    "            #maml_model.save(os.path.join(save_dir, 'best_resnet152.h5'))#change model name here\n",
    "            maml_model.save(os.path.join(save_dir, 'best_efficient_final.h5'), save_format='h5', include_optimizer=True)\n",
    "            maml_model.save_weights(os.path.join(save_dir, 'best_efficient_final_weights.h5'))\n",
    "        else:\n",
    "            # Manual early stopping: give up after 'patience' validations without improvement\n",
    "            wait += 1\n",
    "            if wait >= patience:\n",
    "                print(f'Early stopping at epoch {epoch + 1}')\n",
    "                break\n",
    "\n",
    "# Define path for saving the last epoch model\n",
    "last_model_path = os.path.join(save_dir, 'last_epoch_efficient.h5')\n",
    "# Save the model at the last epoch (full model incl. optimizer state, plus a weights-only file).\n",
    "# This is the state the training loop ended in, as opposed to the 'best_*' files\n",
    "# written inside the loop whenever the validation loss improved.\n",
    "maml_model.save(last_model_path, save_format='h5', include_optimizer=True)\n",
    "maml_model.save_weights(os.path.join(save_dir, 'last_epoch_efficient_weights.h5'))\n",
    "\n",
    "# After training, you can use the test set for the final evaluation\n",
    "# (this evaluates the in-memory model, i.e. the last-epoch state).\n",
    "test_loss, test_acc = maml_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print(f'Final Testing Loss: {test_loss:.4f}, Testing Accuracy: {test_acc * 100:.2f}%')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61d3c95f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Training vs. validation curves: loss on the left, accuracy on the right.\n",
    "fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "panels = (\n",
    "    (loss_ax, 'Loss', train_loss_history, val_loss_history),\n",
    "    (acc_ax, 'Accuracy', train_acc_history, val_acc_history),\n",
    ")\n",
    "for ax, metric, train_curve, val_curve in panels:\n",
    "    ax.plot(train_curve, label=f'Train {metric}')\n",
    "    ax.plot(val_curve, label=f'Validation {metric}')\n",
    "    ax.set_title(f'Training and Validation {metric}')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel(metric)\n",
    "    ax.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c627e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#classification report for the last model\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Score the model that is currently in memory on the held-out test split.\n",
    "test_loss, test_acc = maml_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print(f'Final Testing Loss: {test_loss:.4f}, Testing Accuracy: {test_acc * 100:.2f}%')\n",
    "\n",
    "# Class probabilities for every test image, then the most likely class index.\n",
    "y_pred_prob = maml_model.predict(test_images)\n",
    "y_pred = y_pred_prob.argmax(axis=1)\n",
    "\n",
    "# Per-class precision / recall / F1 summary.\n",
    "last_model_report = classification_report(test_labels, y_pred)\n",
    "print(last_model_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c64b438",
   "metadata": {},
   "outputs": [],
   "source": [
    "#CHECK BEST AND LAST MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3fc2bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "\n",
    "# Locations of the two checkpoints written by the training cell.\n",
    "best_model_path, last_model_path = (\n",
    "    os.path.join(save_dir, file_name)\n",
    "    for file_name in ('best_efficient_final.h5', 'last_epoch_efficient.h5')\n",
    ")\n",
    "\n",
    "# Reload both models from disk.\n",
    "best_model, last_model = (\n",
    "    tf.keras.models.load_model(path) for path in (best_model_path, last_model_path)\n",
    ")\n",
    "\n",
    "# Evaluate both on the same test data (best model first, then last-epoch model).\n",
    "(best_loss, best_acc), (last_loss, last_acc) = (\n",
    "    model.evaluate(test_images, test_labels, verbose=0)\n",
    "    for model in (best_model, last_model)\n",
    ")\n",
    "\n",
    "# Side-by-side summary.\n",
    "print(f'🔹 Best Model - Test Loss: {best_loss:.4f}, Test Accuracy: {best_acc * 100:.2f}%')\n",
    "print(f'🔹 Last Epoch Model - Test Loss: {last_loss:.4f}, Test Accuracy: {last_acc * 100:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6efac16",
   "metadata": {},
   "outputs": [],
   "source": [
    "#CLASSIFICATION REPORT FOR BEST MODEL, NOT FOR THE LAST MODEL\n",
    "\n",
    "# Score the reloaded best checkpoint on the held-out test split.\n",
    "test_loss, test_acc = best_model.evaluate(test_images, test_labels, verbose=0)\n",
    "print(f'Final Testing Loss: {test_loss:.4f}, Testing Accuracy: {test_acc * 100:.2f}%')\n",
    "\n",
    "# Class probabilities for every test image, then the most likely class index.\n",
    "y_pred_prob = best_model.predict(test_images)\n",
    "y_pred = y_pred_prob.argmax(axis=1)\n",
    "\n",
    "# Per-class precision / recall / F1 summary for the best checkpoint.\n",
    "print(\"🔹 Classification Report for Best Model:\")\n",
    "best_model_report = classification_report(test_labels, y_pred)\n",
    "print(best_model_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b86bba1-bb95-4526-982d-6b1211dbc0c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c406cc98-fc70-4e96-ad11-0ed382a86fb1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e687191-916b-48c9-8f04-edf2b9fa9637",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b75aa69a-4198-4787-a1ad-528cc9d52503",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd0e447",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa205569",
   "metadata": {},
   "outputs": [],
   "source": [
    "#MODEL-AGNOSTIC META-LEARNING (MAML)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eab82ad-059a-4914-979b-b3bb08a29182",
   "metadata": {},
   "outputs": [],
   "source": [
    "#FIRST, WE IMPLEMENT A COLOR AUGMENTATION FUNCTION ADAPTED TO SELECTED MODEL (GO UP TO THE BEGINNING OF THE CODE)\n",
    "#SECOND, LOAD THE IMAGES AGAIN (GO UP TO THE BEGINNING OF THE CODE)\n",
    "#THIRD, CHOOSE MODEL BELOW:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52404cd1-4164-49bb-aa1b-e7e17038eb4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#SELECT MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0294411-9474-4da5-8b6a-40595cf85837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ResNet50 backbone with ImageNet weights and no classification head.\n",
    "base_model = tf.keras.applications.ResNet50(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# One pass over the layers: everything is frozen except the last 10 layers,\n",
    "# which stay trainable for better adaptation.\n",
    "first_trainable = len(base_model.layers) - 10\n",
    "for position, layer in enumerate(base_model.layers):\n",
    "    layer.trainable = position >= first_trainable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1d56025c-188c-4b93-bfcc-b237bd59b8e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import ResNet152\n",
    "\n",
    "# ResNet152 backbone with ImageNet weights and no classification head.\n",
    "base_model = ResNet152(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# One pass over the layers: everything is frozen except the last 10 layers,\n",
    "# which stay trainable for better adaptation.\n",
    "first_trainable = len(base_model.layers) - 10\n",
    "for position, layer in enumerate(base_model.layers):\n",
    "    layer.trainable = position >= first_trainable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "46b8e762-5472-4ac6-9cad-cb3e791253d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.xception import Xception\n",
    "\n",
    "# Xception backbone with ImageNet weights and no classification head.\n",
    "# Note the 299 x 299 input size used for this backbone.\n",
    "base_model = Xception(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(299, 299, 3),\n",
    ")\n",
    "\n",
    "# One pass over the layers: everything is frozen except the last 10 layers,\n",
    "# which stay trainable for better adaptation.\n",
    "first_trainable = len(base_model.layers) - 10\n",
    "for position, layer in enumerate(base_model.layers):\n",
    "    layer.trainable = position >= first_trainable\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c55d3456-89ed-4805-8e1d-90cc6134341c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import EfficientNetV2L  # change L to M to switch variant\n",
    "\n",
    "# EfficientNetV2L backbone with ImageNet weights and no classification head.\n",
    "# Note the 299 x 299 input size used for this backbone.\n",
    "base_model = EfficientNetV2L(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(299, 299, 3),\n",
    ")\n",
    "\n",
    "# One pass over the layers: everything is frozen except the last 10 layers,\n",
    "# which stay trainable for better adaptation.\n",
    "first_trainable = len(base_model.layers) - 10\n",
    "for position, layer in enumerate(base_model.layers):\n",
    "    layer.trainable = position >= first_trainable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "97ca9c76-8d19-489a-986b-bcb3db96d785",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications import DenseNet201\n",
    "\n",
    "# DenseNet201 backbone with ImageNet weights and no classification head.\n",
    "base_model = DenseNet201(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(224, 224, 3),\n",
    ")\n",
    "\n",
    "# One pass over the layers: everything is frozen except the last 10 layers,\n",
    "# which stay trainable for better adaptation.\n",
    "first_trainable = len(base_model.layers) - 10\n",
    "for position, layer in enumerate(base_model.layers):\n",
    "    layer.trainable = position >= first_trainable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c8e5be-e7c6-481a-aa03-eb531b076b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#NOW RUN THE MAML MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94507a45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1/100\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, models, optimizers\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Seed TensorFlow, NumPy and Python's `random` module with the same value for reproducibility\n",
    "for set_seed in (tf.random.set_seed, np.random.seed, random.seed):\n",
    "    set_seed(42)\n",
    "\n",
    "\n",
    "def create_maml_model(base_model, num_classes):\n",
    "    \"\"\"Stack a small classification head on top of `base_model`.\n",
    "\n",
    "    Head, in order: Conv2D(512, 3x3, ReLU, 'same' padding), GlobalAveragePooling2D,\n",
    "    Dense(512, ReLU), Dropout(0.5), BatchNormalization, Dense(num_classes, softmax).\n",
    "\n",
    "    Args:\n",
    "        base_model: Keras model used as the first element of the Sequential stack.\n",
    "        num_classes: Number of units in the final softmax layer.\n",
    "\n",
    "    Returns:\n",
    "        The assembled (uncompiled) `models.Sequential` instance.\n",
    "    \"\"\"\n",
    "    head_layers = (\n",
    "        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),\n",
    "        layers.GlobalAveragePooling2D(),\n",
    "        layers.Dense(512, activation='relu'),\n",
    "        layers.Dropout(0.5),\n",
    "        layers.BatchNormalization(),\n",
    "        layers.Dense(num_classes, activation='softmax'),\n",
    "    )\n",
    "\n",
    "    # Build the stack incrementally: backbone first, then each head layer in order\n",
    "    maml_model = models.Sequential()\n",
    "    maml_model.add(base_model)\n",
    "    for head_layer in head_layers:\n",
    "        maml_model.add(head_layer)\n",
    "    return maml_model\n",
    "\n",
    "# MAML hyperparameters\n",
    "num_classes = 4                 # units in the softmax output layer\n",
    "n_way = num_classes             # classes sampled for every task\n",
    "k_shot, q_query = 5, 3          # support / query images drawn per class and task\n",
    "num_tasks, num_epochs = 40, 100 # tasks per epoch / maximum number of epochs\n",
    "alpha, beta = 0.01, 0.001       # inner-loop step size / meta-optimizer learning rate\n",
    "inner_steps = 5                 # inner-loop iterations per task\n",
    "\n",
    "# Build the model on the backbone currently bound to `base_model`, plus its meta-optimizer\n",
    "maml_model = create_maml_model(base_model, num_classes)\n",
    "meta_optimizer = optimizers.Adam(learning_rate=beta)\n",
    "#meta_optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.001)\n",
    "\n",
    "# Checkpoint directory (created when it does not exist yet)\n",
    "save_dir = 'C:\\\\Users\\\\oscar\\\\Documents\\\\experimentos\\\\2025_Four_carnivores_MAML\\\\bestmodels\\\\bestmodelsMAML'\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Data augmentation generator (not referenced anywhere else in this cell)\n",
    "datagen = ImageDataGenerator(\n",
    "    horizontal_flip=True,\n",
    "    rotation_range=20,\n",
    "    height_shift_range=0.2,\n",
    "    width_shift_range=0.2\n",
    ")\n",
    "\n",
    "def maml_task_step(model, support_images, support_labels, query_images, query_labels):\n",
    "    \"\"\"Compute the first-order MAML meta-gradient for a single task.\n",
    "\n",
    "    The model is adapted on the support set with `inner_steps` SGD steps of size\n",
    "    `alpha` (both read from the notebook globals). The query loss is then\n",
    "    differentiated at the adapted weights, and finally the trainable variables are\n",
    "    reset to the values they had on entry, so the caller's meta-parameters are only\n",
    "    changed by the meta-optimizer.\n",
    "\n",
    "    Args:\n",
    "        model: Keras model whose `trainable_variables` are the meta-parameters.\n",
    "        support_images: Batch of images used for the inner-loop adaptation.\n",
    "        support_labels: Integer class labels matching `support_images`.\n",
    "        query_images: Batch of images used to evaluate the adapted model.\n",
    "        query_labels: Integer class labels matching `query_images`.\n",
    "\n",
    "    Returns:\n",
    "        (meta_grads, query_loss): gradients of the mean query loss with respect to\n",
    "        `model.trainable_variables`, evaluated at the adapted weights (an entry is\n",
    "        None for a variable that did not influence the loss), and the scalar mean\n",
    "        query loss.\n",
    "\n",
    "    Note:\n",
    "        Only trainable variables are restored. State that `training=True` forward\n",
    "        passes update (e.g. BatchNormalization moving statistics) is not rolled back.\n",
    "    \"\"\"\n",
    "    variables = model.trainable_variables\n",
    "    # Snapshot the meta-parameters; tf.identity copies the value each variable holds now\n",
    "    initial_values = [tf.identity(var) for var in variables]\n",
    "\n",
    "    try:\n",
    "        # Inner loop: SGD applied in place, so every step differentiates with respect\n",
    "        # to the weights the forward pass actually uses (gradients taken against a\n",
    "        # detached list of tensors are None and would turn later steps into no-ops)\n",
    "        for _ in range(inner_steps):\n",
    "            with tf.GradientTape() as tape:\n",
    "                support_preds = model(support_images, training=True)\n",
    "                support_loss = tf.reduce_mean(\n",
    "                    tf.keras.losses.sparse_categorical_crossentropy(support_labels, support_preds)\n",
    "                )\n",
    "            grads = tape.gradient(support_loss, variables)\n",
    "            for var, grad in zip(variables, grads):\n",
    "                if grad is not None:\n",
    "                    var.assign_sub(alpha * grad)\n",
    "\n",
    "        # Outer objective: query loss and its gradient at the adapted weights\n",
    "        with tf.GradientTape() as meta_tape:\n",
    "            query_preds = model(query_images, training=True)\n",
    "            query_loss = tf.reduce_mean(\n",
    "                tf.keras.losses.sparse_categorical_crossentropy(query_labels, query_preds)\n",
    "            )\n",
    "        meta_grads = meta_tape.gradient(query_loss, variables)\n",
    "    finally:\n",
    "        # Undo the task-specific adaptation even if a step above raised\n",
    "        for var, value in zip(variables, initial_values):\n",
    "            var.assign(value)\n",
    "\n",
    "    return meta_grads, query_loss\n",
    "\n",
    "# Per-epoch metric histories, consumed by the plots at the end of this cell\n",
    "train_loss_history = []\n",
    "val_loss_history = []\n",
    "train_acc_history = []\n",
    "val_acc_history = []\n",
    "\n",
    "# Early-stopping state\n",
    "best_val_loss = float('inf')\n",
    "best_weights = None  # in-memory copy of the weights at the best validation loss so far\n",
    "patience = 15        # epochs without validation improvement tolerated before stopping\n",
    "wait = 0\n",
    "\n",
    "# MAML training loop: one meta-update per epoch, averaged over `num_tasks` sampled tasks\n",
    "for epoch in range(num_epochs):\n",
    "    print(f\"\\nEpoch {epoch + 1}/{num_epochs}\")\n",
    "    meta_grads_accum = [tf.zeros_like(var) for var in maml_model.trainable_variables]\n",
    "    epoch_loss = []\n",
    "\n",
    "    for task in range(num_tasks):\n",
    "        support_images, support_labels, query_images, query_labels = [], [], [], []\n",
    "\n",
    "        # Draw k_shot support and q_query query images per class, without overlap\n",
    "        for class_idx in range(n_way):\n",
    "            class_indices = np.where(train_labels == class_idx)[0]\n",
    "            selected_indices = np.random.choice(class_indices, k_shot + q_query, replace=False)\n",
    "            support_idx = selected_indices[:k_shot]\n",
    "            query_idx = selected_indices[k_shot:]\n",
    "\n",
    "            support_images.extend(train_images[support_idx])\n",
    "            support_labels.extend([class_idx] * k_shot)\n",
    "            query_images.extend(train_images[query_idx])\n",
    "            query_labels.extend([class_idx] * q_query)\n",
    "\n",
    "        support_images = tf.convert_to_tensor(np.array(support_images), dtype=tf.float32)\n",
    "        support_labels = tf.convert_to_tensor(np.array(support_labels), dtype=tf.int32)\n",
    "        query_images = tf.convert_to_tensor(np.array(query_images), dtype=tf.float32)\n",
    "        query_labels = tf.convert_to_tensor(np.array(query_labels), dtype=tf.int32)\n",
    "\n",
    "        meta_grads, loss = maml_task_step(\n",
    "            maml_model, support_images, support_labels, query_images, query_labels\n",
    "        )\n",
    "\n",
    "        # Variables that received no gradient contribute zeros to the running sum\n",
    "        meta_grads_accum = [\n",
    "            acc + (g if g is not None else tf.zeros_like(acc))\n",
    "            for acc, g in zip(meta_grads_accum, meta_grads)\n",
    "        ]\n",
    "        epoch_loss.append(loss.numpy())\n",
    "\n",
    "    # Meta-update with the task-averaged gradient\n",
    "    mean_grads = [g / num_tasks for g in meta_grads_accum]\n",
    "    meta_optimizer.apply_gradients(zip(mean_grads, maml_model.trainable_variables))\n",
    "\n",
    "    mean_query_loss = np.mean(epoch_loss)\n",
    "    train_loss_history.append(mean_query_loss)\n",
    "\n",
    "    train_preds = maml_model.predict(train_images)\n",
    "    train_acc = np.mean(np.argmax(train_preds, axis=1) == train_labels)\n",
    "    train_acc_history.append(train_acc)\n",
    "\n",
    "    print(f\"Epoch {epoch + 1}: Mean Query Loss: {mean_query_loss:.4f}\")\n",
    "\n",
    "    val_preds = maml_model.predict(val_images)\n",
    "    val_loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(val_labels, val_preds)).numpy()\n",
    "    val_acc = np.mean(np.argmax(val_preds, axis=1) == val_labels)\n",
    "\n",
    "    val_loss_history.append(val_loss)\n",
    "    val_acc_history.append(val_acc)\n",
    "\n",
    "    print(f\"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc * 100:.2f}%\")\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        best_weights = maml_model.get_weights()\n",
    "        wait = 0\n",
    "        maml_model.save(os.path.join(save_dir, 'best_model.keras'))\n",
    "    else:\n",
    "        wait += 1\n",
    "        if wait >= patience:\n",
    "            print(f\"Early stopping at epoch {epoch + 1}\")\n",
    "            break\n",
    "\n",
    "# Roll the in-memory model back to its best validation epoch, so the evaluation that\n",
    "# follows scores the same weights that were written to 'best_model.keras'\n",
    "if best_weights is not None:\n",
    "    maml_model.set_weights(best_weights)\n",
    "\n",
    "# Final test: score the weights currently held by `maml_model` on the test split\n",
    "test_preds = maml_model.predict(test_images)\n",
    "test_losses = tf.keras.losses.sparse_categorical_crossentropy(test_labels, test_preds)\n",
    "test_loss = tf.reduce_mean(test_losses).numpy()\n",
    "test_acc = np.mean(test_preds.argmax(axis=1) == test_labels)\n",
    "print(f'Final Test Loss: {test_loss:.4f}, Accuracy: {test_acc * 100:.2f}%')\n",
    "\n",
    "# Plot the training and validation curves side by side: loss on the left, accuracy on the right\n",
    "fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "panels = [\n",
    "    (loss_ax, 'Loss', train_loss_history, val_loss_history),\n",
    "    (acc_ax, 'Accuracy', train_acc_history, val_acc_history),\n",
    "]\n",
    "\n",
    "for ax, metric, train_curve, val_curve in panels:\n",
    "    ax.plot(train_curve, label=f'Train {metric}', color='blue')\n",
    "    ax.plot(val_curve, label=f'Validation {metric}', color='orange')\n",
    "    ax.set_title(f'Training and Validation {metric}')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel(metric)\n",
    "    ax.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a38643f-9a44-43f4-9f31-9ffd9bcf1d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set report: loss/accuracy, per-class metrics and confusion matrix\n",
    "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
    "\n",
    "# Display names, listed in label-index order (0..3)\n",
    "report_class_names = ['crocodiles', 'hyenas', 'leopards','lions']\n",
    "\n",
    "test_preds = maml_model.predict(test_images)\n",
    "test_pred_labels = test_preds.argmax(axis=1)\n",
    "\n",
    "test_losses = tf.keras.losses.sparse_categorical_crossentropy(test_labels, test_preds)\n",
    "test_loss = tf.reduce_mean(test_losses).numpy()\n",
    "test_acc = np.mean(test_pred_labels == test_labels)\n",
    "print(f'Final Test Loss: {test_loss:.4f}, Accuracy: {test_acc * 100:.2f}%')\n",
    "\n",
    "# Classification report\n",
    "print(\"\\nClassification Report:\")\n",
    "print(classification_report(test_labels, test_pred_labels, target_names=report_class_names))\n",
    "\n",
    "# Confusion matrix\n",
    "cm = confusion_matrix(test_labels, test_pred_labels)\n",
    "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=report_class_names)\n",
    "disp.plot(cmap='Blues', values_format='d')\n",
    "plt.title(\"Confusion Matrix\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e43bd6b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "906acc1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "# Reload a previously saved Keras model from the MAML checkpoint folder.\n",
    "# NOTE(review): the training cell writes 'best_model.keras'; 'resnet50.keras' is\n",
    "# presumably a manually renamed copy of such a checkpoint - confirm the file exists.\n",
    "model_path = r\"C:\\Users\\oscar\\Documents\\experimentos\\2025_Four_carnivores_MAML\\bestmodels\\bestmodelsMAML\\resnet50.keras\"\n",
    "churri_model = load_model(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0547d4e4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f278df62",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75fdb63b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
